import json
import algorithm

class LinearAttack:
    def __init__(self):
        self.SPNenc = algorithm.Encrpt() # 基本的加密算法
        self.linear_table = [[-8 for i in range(16)] for j in range(16)]
        json_obj = json.load(open('Plain_Cipher_pairs.json','r'))
        self.PC_pairs = json_obj['PC_pairs']
        # linear_expression:攻击使用的线性表达式，用包含两个16比特的整数表示
        # 下面的取值就表示：P5 ^ P7 ^ P8 ^ U6 ^ U8 ^ U14 ^ U16 = 0
        self.linear_expression = [0b0000101100000000, 0b0000010100000101]
        self.tpk_count = [-5000 for _ in range(256)] # 每个target partial subkey的计数器

    # 线性近似表的生成
    def gen_linear_table(self):
        # 生成线性近似表，即论文中的Table 4
        for input_sum in range(16):
            for output_sum in range(16):
                for x in range(16):
                    y = self.SPNenc.sbox(x)
                    if self.check_linear(input_sum,output_sum,x,y,4):
                        self.linear_table[input_sum][output_sum] += 1

    def check_linear(self,input_sum,output_sum,x,y,length):
        """
        判断输入x和输出y是否满足线性近似的表达式，表达式根据input_sum和output_sum 给出
        :param input_sum: length比特位的整数，表示输入x中哪些位参与到线性近似的表达式中
        :param output_sum: length比特位的整数，表示输出y中哪些位参与到线性近似的表达式中
        :param x: 输入x
        :param y: 输出y
        :param length: x与y的比特长度
        :return: True/False
        """
        # e.g. input_sum = 0b0001 , output_sum = 0b1110, x = 0b0101, y = 0b0011 ,length = 4
        # 线性近似的式子根据input_sum和output_sum确定为： x3 ^ y0 ^ y1 ^ y2 = 0
        # 带入x和y的值发现表达式不成立，返回False

        # 该过程包含一点小推导，不是直观实现方式
        assist_num = (input_sum & x) ^ (output_sum & y)
        ret = 0
        for _ in range(length):
            ret ^= (assist_num & 1)
            assist_num >>= 1
        if ret == 0:
            return True
        else:
            return False

    def dec_partial(self,ciphertext,rk):
        # 对密文进行部分解密
        beforeMixKey = ciphertext ^ rk
        beforeSbox = self.SPNenc.sbox_inv(beforeMixKey)
        return beforeSbox

    def linear_attack(self):
        # 线性攻击生成第五轮目标子密钥
        input_sum,output_sum = self.linear_expression
        for plain,cipher in self.PC_pairs:
            for tpk in range(256):
                # 将部分子密钥补充为完整密钥
                subkey = (((tpk >> 4) & 0xf) << 8) ^ (tpk & 0xf)
                U = self.dec_partial(cipher,subkey)
                if self.check_linear(input_sum,output_sum,plain,U,16):
                    self.tpk_count[tpk] += 1

        # 寻找count对应偏差最大的子密钥作为攻击结果
        max, max_ind = 0,0
        for tpk in range(256):
            self.tpk_count[tpk] = abs(self.tpk_count[tpk])
            self.tpk_count[tpk] /= 10000
            if max < self.tpk_count[tpk]:
                max = self.tpk_count[tpk]
                max_ind = tpk
        print("目标部分子密钥分别为：",hex(max_ind>>4),hex(max_ind & 0xf))
        print("实际第五轮子密钥为：",hex(self.SPNenc.rk[4]))


if __name__ == "__main__":
    attack = LinearAttack()

    # 测试生成线性近似表
    attack.gen_linear_table()
    print("线性近似表的生成结果：")
    for row in attack.linear_table:
        print(row)

    # 测试线性攻击能否恢复target partial key
    attack.linear_attack()
